{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "eb4f5fec-1a5c-4b6f-b583-fb369472e94b",
   "metadata": {},
   "source": [
    "# PSI On SPU"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "4172c7c4",
   "metadata": {},
   "source": [
The following codes">
    ">The following code is for demonstration only. It is **NOT for production** use because of system security concerns; please **DO NOT** use it directly in production."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "f9b632d8-b12f-44a1-8a75-5d9c0a704a38",
   "metadata": {},
   "source": [
    "PSI([Private Set Intersection](https://en.wikipedia.org/wiki/Private_set_intersection)) is a cryptographic technique that allows two parties holding sets to compare encrypted versions of these sets in order to compute the intersection. In this scenario, neither party reveals anything to the counterparty except for the elements in the intersection.\n",
    "\n",
    "In SecretFlow, SPU device supports three PSI protocol:\n",
    "\n",
    "- [ECDH](https://ieeexplore.ieee.org/document/6234849/)：semi-honest, based on public key encryption, suitable for small datasets.\n",
    "- [KKRT](https://eprint.iacr.org/2016/799.pdf)：semi-host, based on cuckoo hashing and OT extension, suitable for large datasets.\n",
    "- [BC22PCG](https://eprint.iacr.org/2022/334): semi-host, psi from pseudorandom correlation generators.\n",
    "\n",
    "Before we start, we need to initialize the environment. The following three nodes `alice`, `bob`, and `carol` are created on a single machine to simulate multiple participants."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "3d7c4fa2-ea20-4e0d-b1ad-648cce23e729",
   "metadata": {},
   "outputs": [],
   "source": [
    "import secretflow as sf\n",
    "\n",
    "# Check the version of your SecretFlow\n",
    "print('The version of SecretFlow: {}'.format(sf.__version__))\n",
    "\n",
    "# In case you have a running secretflow runtime already.\n",
    "sf.shutdown()\n",
    "\n",
    "sf.init(['alice', 'bob', 'carol'], address='local')"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "5ed0a08b-3aa4-4fa6-9e1d-0caba207bdf5",
   "metadata": {},
   "source": [
    "## Preparing dataset"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "b4c16f07-1c67-4bad-af70-d8a4fe9266f3",
   "metadata": {},
   "source": [
    "First, we need a dataset for constructing vertical partitioned scenarios. For simplicity, we use [iris](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.load_iris.html) dataset here. We add two columns to it for subsequent single-column and multi-column intersection demonstrations\n",
    "\n",
    "- uid：Sample unique ID.\n",
    "- month：Simulate a scenario where samples are generated monthly. The first 50% of the samples are generated in January, and the last 50% of the samples are generated in February."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "31f0a010-0a2e-4ee2-996a-169d7cb2731d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>sepal length (cm)</th>\n",
       "      <th>sepal width (cm)</th>\n",
       "      <th>petal length (cm)</th>\n",
       "      <th>petal width (cm)</th>\n",
       "      <th>uid</th>\n",
       "      <th>month</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>5.1</td>\n",
       "      <td>3.5</td>\n",
       "      <td>1.4</td>\n",
       "      <td>0.2</td>\n",
       "      <td>0</td>\n",
       "      <td>Jan</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>4.9</td>\n",
       "      <td>3.0</td>\n",
       "      <td>1.4</td>\n",
       "      <td>0.2</td>\n",
       "      <td>1</td>\n",
       "      <td>Jan</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>4.7</td>\n",
       "      <td>3.2</td>\n",
       "      <td>1.3</td>\n",
       "      <td>0.2</td>\n",
       "      <td>2</td>\n",
       "      <td>Jan</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>4.6</td>\n",
       "      <td>3.1</td>\n",
       "      <td>1.5</td>\n",
       "      <td>0.2</td>\n",
       "      <td>3</td>\n",
       "      <td>Jan</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>5.0</td>\n",
       "      <td>3.6</td>\n",
       "      <td>1.4</td>\n",
       "      <td>0.2</td>\n",
       "      <td>4</td>\n",
       "      <td>Jan</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>145</th>\n",
       "      <td>6.7</td>\n",
       "      <td>3.0</td>\n",
       "      <td>5.2</td>\n",
       "      <td>2.3</td>\n",
       "      <td>145</td>\n",
       "      <td>Feb</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>146</th>\n",
       "      <td>6.3</td>\n",
       "      <td>2.5</td>\n",
       "      <td>5.0</td>\n",
       "      <td>1.9</td>\n",
       "      <td>146</td>\n",
       "      <td>Feb</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>147</th>\n",
       "      <td>6.5</td>\n",
       "      <td>3.0</td>\n",
       "      <td>5.2</td>\n",
       "      <td>2.0</td>\n",
       "      <td>147</td>\n",
       "      <td>Feb</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>148</th>\n",
       "      <td>6.2</td>\n",
       "      <td>3.4</td>\n",
       "      <td>5.4</td>\n",
       "      <td>2.3</td>\n",
       "      <td>148</td>\n",
       "      <td>Feb</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>149</th>\n",
       "      <td>5.9</td>\n",
       "      <td>3.0</td>\n",
       "      <td>5.1</td>\n",
       "      <td>1.8</td>\n",
       "      <td>149</td>\n",
       "      <td>Feb</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>150 rows × 6 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "     sepal length (cm)  sepal width (cm)  petal length (cm)  petal width (cm)  \\\n",
       "0                  5.1               3.5                1.4               0.2   \n",
       "1                  4.9               3.0                1.4               0.2   \n",
       "2                  4.7               3.2                1.3               0.2   \n",
       "3                  4.6               3.1                1.5               0.2   \n",
       "4                  5.0               3.6                1.4               0.2   \n",
       "..                 ...               ...                ...               ...   \n",
       "145                6.7               3.0                5.2               2.3   \n",
       "146                6.3               2.5                5.0               1.9   \n",
       "147                6.5               3.0                5.2               2.0   \n",
       "148                6.2               3.4                5.4               2.3   \n",
       "149                5.9               3.0                5.1               1.8   \n",
       "\n",
       "     uid month  \n",
       "0      0   Jan  \n",
       "1      1   Jan  \n",
       "2      2   Jan  \n",
       "3      3   Jan  \n",
       "4      4   Jan  \n",
       "..   ...   ...  \n",
       "145  145   Feb  \n",
       "146  146   Feb  \n",
       "147  147   Feb  \n",
       "148  148   Feb  \n",
       "149  149   Feb  \n",
       "\n",
       "[150 rows x 6 columns]"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import numpy as np\n",
    "from sklearn.datasets import load_iris\n",
    "\n",
    "data, target = load_iris(return_X_y=True, as_frame=True)\n",
    "data['uid'] = np.arange(len(data)).astype('str')\n",
    "data['month'] = ['Jan'] * 75 + ['Feb'] * 75\n",
    "\n",
    "data"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "1bb263ff-e8a3-4066-ab30-ea01030a9f18",
   "metadata": {},
   "source": [
    "In the actual scenario, the sample data is provided by each participant, and the fields for intersection need to be agreed in advance:\n",
    "\n",
    "- The intersection field can be single or multiple.\n",
    "- The intersection field must be unique. If there is a duplicate, it needs to be deduplicated in advance.\n",
    "\n",
    "For example, The following is the data provided by alice for PSI intersection, the intersection field is `uid` and `month`，we can see that [1, 'Jan'] is duplicated.\n",
    "```\n",
    "alice.csv\n",
    "---------\n",
    "uid   month   c0\n",
    "1     Jan     5.8\n",
    "2     Jan     5.4\n",
    "1     Jan     5.8\n",
    "1     Feb     7.4\n",
    "```\n",
    "The data after deduplication is\n",
    "```\n",
    "alice.csv\n",
    "---------\n",
    "uid   month   c0\n",
    "1     Jan     5.8\n",
    "2     Jan     5.4\n",
    "1     Feb     7.4\n",
    "```\n",
    "We randomly sample the iris data three times to simulate the data provided by `alice`, `bob`, and `carol`, and the three data are in an unaligned state.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "037542dd-7945-4665-9d6a-7d805ea52b20",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "os.makedirs('.data', exist_ok=True)\n",
    "da, db, dc = data.sample(frac=0.9), data.sample(frac=0.8), data.sample(frac=0.7)\n",
    "\n",
    "da.to_csv('.data/alice.csv', index=False)\n",
    "db.to_csv('.data/bob.csv', index=False)\n",
    "dc.to_csv('.data/carol.csv', index=False)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "d54451db-c9ba-45ca-877b-06619e03215f",
   "metadata": {},
   "source": [
    "## Two parties PSI"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "7c12512f-4889-4f71-a539-e5bb7f56a9a5",
   "metadata": {},
   "source": [
    "We virtualize three logical devices on the physical device:\n",
    "\n",
    "- alice, bob: PYU device, responsible for the local plaintext computation of the participant.\n",
    "- spu：SPUdevice, consists of alice and bob, responsible for the ciphertext calculation of the two parties."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "ff729adb-f89a-499d-999f-6d884f2203e0",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "alice, bob = sf.PYU('alice'), sf.PYU('bob')\n",
    "spu = sf.SPU(sf.utils.testing.cluster_def(['alice', 'bob']))"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "fb2b161a-5aa9-4d06-b866-48b25273e48b",
   "metadata": {},
   "source": [
    "### Single-column PSI"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "c12444e3-1da0-426e-add2-609770f8f259",
   "metadata": {},
   "source": [
    "Next, we use `uid` to intersect the two data, SPU provide `psi_csv` which take the csv file as input and generate the csv file after the intersection. The default protocol is KKRT."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "dec15656-51c1-4498-9f78-b2bcfd5168fb",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "input_path = {alice: '.data/alice.csv', bob: '.data/bob.csv'}\n",
    "output_path = {alice: '.data/alice_psi.csv', bob: '.data/bob_psi.csv'}\n",
    "spu.psi_csv('uid', input_path, output_path, 'alice')"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "1ef80302-548f-4277-a460-fee0a07b7728",
   "metadata": {},
   "source": [
    "To check the correctness of the results, we use [pandas.DataFrame.join](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.join.html) to inner join da and db. It can be seen that the two data have been aligned according to `uid` and sorted according to their lexicographical order."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "ec091c3a-83bc-41f4-85d5-351c9d98c643",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>sepal length (cm)</th>\n",
       "      <th>sepal width (cm)</th>\n",
       "      <th>petal length (cm)</th>\n",
       "      <th>petal width (cm)</th>\n",
       "      <th>uid</th>\n",
       "      <th>month</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>6.3</td>\n",
       "      <td>3.3</td>\n",
       "      <td>6.0</td>\n",
       "      <td>2.5</td>\n",
       "      <td>100</td>\n",
       "      <td>Feb</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>5.8</td>\n",
       "      <td>2.7</td>\n",
       "      <td>5.1</td>\n",
       "      <td>1.9</td>\n",
       "      <td>101</td>\n",
       "      <td>Feb</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>7.1</td>\n",
       "      <td>3.0</td>\n",
       "      <td>5.9</td>\n",
       "      <td>2.1</td>\n",
       "      <td>102</td>\n",
       "      <td>Feb</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>6.3</td>\n",
       "      <td>2.9</td>\n",
       "      <td>5.6</td>\n",
       "      <td>1.8</td>\n",
       "      <td>103</td>\n",
       "      <td>Feb</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>6.5</td>\n",
       "      <td>3.0</td>\n",
       "      <td>5.8</td>\n",
       "      <td>2.2</td>\n",
       "      <td>104</td>\n",
       "      <td>Feb</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>101</th>\n",
       "      <td>5.6</td>\n",
       "      <td>2.7</td>\n",
       "      <td>4.2</td>\n",
       "      <td>1.3</td>\n",
       "      <td>94</td>\n",
       "      <td>Feb</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>102</th>\n",
       "      <td>5.7</td>\n",
       "      <td>2.9</td>\n",
       "      <td>4.2</td>\n",
       "      <td>1.3</td>\n",
       "      <td>96</td>\n",
       "      <td>Feb</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>103</th>\n",
       "      <td>6.2</td>\n",
       "      <td>2.9</td>\n",
       "      <td>4.3</td>\n",
       "      <td>1.3</td>\n",
       "      <td>97</td>\n",
       "      <td>Feb</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>104</th>\n",
       "      <td>5.1</td>\n",
       "      <td>2.5</td>\n",
       "      <td>3.0</td>\n",
       "      <td>1.1</td>\n",
       "      <td>98</td>\n",
       "      <td>Feb</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>105</th>\n",
       "      <td>5.7</td>\n",
       "      <td>2.8</td>\n",
       "      <td>4.1</td>\n",
       "      <td>1.3</td>\n",
       "      <td>99</td>\n",
       "      <td>Feb</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>106 rows × 6 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "     sepal length (cm)  sepal width (cm)  petal length (cm)  petal width (cm)  \\\n",
       "0                  6.3               3.3                6.0               2.5   \n",
       "1                  5.8               2.7                5.1               1.9   \n",
       "2                  7.1               3.0                5.9               2.1   \n",
       "3                  6.3               2.9                5.6               1.8   \n",
       "4                  6.5               3.0                5.8               2.2   \n",
       "..                 ...               ...                ...               ...   \n",
       "101                5.6               2.7                4.2               1.3   \n",
       "102                5.7               2.9                4.2               1.3   \n",
       "103                6.2               2.9                4.3               1.3   \n",
       "104                5.1               2.5                3.0               1.1   \n",
       "105                5.7               2.8                4.1               1.3   \n",
       "\n",
       "     uid month  \n",
       "0    100   Feb  \n",
       "1    101   Feb  \n",
       "2    102   Feb  \n",
       "3    103   Feb  \n",
       "4    104   Feb  \n",
       "..   ...   ...  \n",
       "101   94   Feb  \n",
       "102   96   Feb  \n",
       "103   97   Feb  \n",
       "104   98   Feb  \n",
       "105   99   Feb  \n",
       "\n",
       "[106 rows x 6 columns]"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import pandas as pd\n",
    "\n",
    "df = da.join(db.set_index('uid'), on='uid', how='inner', rsuffix='_bob', sort=True)\n",
    "expected = df[da.columns].astype({'uid': 'int64'}).reset_index(drop=True)\n",
    "\n",
    "da_psi = pd.read_csv('.data/alice_psi.csv')\n",
    "db_psi = pd.read_csv('.data/bob_psi.csv')\n",
    "\n",
    "pd.testing.assert_frame_equal(da_psi, expected)\n",
    "pd.testing.assert_frame_equal(db_psi, expected)\n",
    "\n",
    "da_psi"
   ]
  },
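  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "7d2e4b1a-5c3f-4e8a-9b6d-1f0a2c3e4d5b",
   "metadata": {},
   "source": [
    "The row order in the output above, which starts at `uid` 100 and ends at 99, follows from comparing `uid` values as strings. A small plain-Python illustration with made-up values:\n",
    "\n",
    "```python\n",
    "uids = ['94', '100', '99', '101']\n",
    "\n",
    "# String comparison goes character by character, so '100' sorts before '94'.\n",
    "lexicographic = sorted(uids)\n",
    "\n",
    "# Numeric comparison gives the order one might expect for integer IDs.\n",
    "numeric = sorted(uids, key=int)\n",
    "\n",
    "print(lexicographic)\n",
    "print(numeric)\n",
    "```"
   ]
  },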
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "4e1015b1-4062-4f40-8c5b-3b3e346cd4d2",
   "metadata": {},
   "source": [
    "### Multi-columns PSI"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "26bdf77c-dd58-4053-a3f4-f69659c8c96e",
   "metadata": {},
   "source": [
    "We can also use multiple fields to intersect, the following demonstrates the use of `uid` and `month` to intersect two data. In terms of implementation, multiple fields are concatenated into a string, so please ensure that there is no duplication of the multi-column composite primary key."
   ]
  },
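  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "2f8c6a90-b1d4-4c7e-8a35-9e0b7d6c5f41",
   "metadata": {},
   "source": [
    "A plaintext sketch of the composite-key idea, assuming the fields are joined with a `,` separator. The separator and the helper `composite_key` are hypothetical, and the exact concatenation used inside SPU is not specified here. The sketch has no privacy protection and only shows how several fields act as one key.\n",
    "\n",
    "```python\n",
    "def composite_key(row, fields, sep=','):\n",
    "    # Hypothetical helper: join the selected field values into one string key.\n",
    "    return sep.join(str(row[f]) for f in fields)\n",
    "\n",
    "\n",
    "alice_rows = [\n",
    "    {'uid': '1', 'month': 'Jan'},\n",
    "    {'uid': '2', 'month': 'Jan'},\n",
    "    {'uid': '1', 'month': 'Feb'},\n",
    "]\n",
    "bob_rows = [\n",
    "    {'uid': '1', 'month': 'Feb'},\n",
    "    {'uid': '2', 'month': 'Jan'},\n",
    "    {'uid': '3', 'month': 'Jan'},\n",
    "]\n",
    "\n",
    "fields = ['uid', 'month']\n",
    "alice_keys = {composite_key(r, fields) for r in alice_rows}\n",
    "bob_keys = {composite_key(r, fields) for r in bob_rows}\n",
    "\n",
    "# A row is in the intersection only if every key field matches.\n",
    "common = sorted(alice_keys & bob_keys)\n",
    "print(common)\n",
    "```"
   ]
  },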
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "db24b582-ef58-4791-89d5-074be619d23f",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "spu.psi_csv(['uid', 'month'], input_path, output_path, 'alice')"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "f7b34956-48c7-4e0f-bb6a-013508866267",
   "metadata": {},
   "source": [
    "Similarly, we use pandas.DataFrame.join to verify the correctness of the result, we can see that the two data have been aligned according to `uid` and `month`, and sorted according to their lexicographical order."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "aebb3a76-977b-4856-b17d-066fd43230c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "df = da.join(db.set_index(['uid', 'month']), on=['uid', 'month'], how='inner', rsuffix='_bob', sort=True)\n",
    "expected = df[da.columns].astype({'uid': 'int64'}).reset_index(drop=True)\n",
    "\n",
    "da_psi = pd.read_csv('.data/alice_psi.csv')\n",
    "db_psi = pd.read_csv('.data/bob_psi.csv')\n",
    "\n",
    "pd.testing.assert_frame_equal(da_psi, expected)\n",
    "pd.testing.assert_frame_equal(db_psi, expected)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "f87f2dc4-c6c8-409a-b83d-fa15661def83",
   "metadata": {},
   "source": [
    "## Three parties PSI"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "82d884a3-4c6b-47ac-8526-ff486e89bcd4",
   "metadata": {},
   "source": [
    "Next, we add a third-party `carol`, and create a PYU device for it, as well as an SPU device built by the third party."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "7eb7050f-70e4-4dc6-9f07-b19b8aeb69bc",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "carol = sf.PYU('carol')\n",
    "spu_3pc = sf.SPU(sf.utils.testing.cluster_def(['alice', 'bob', 'carol']))"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "eee3d13d-c648-4072-9ebd-c50787701262",
   "metadata": {},
   "source": [
    "Then, use `uid` and `month` as the composite primary key to perform a three-way negotiation. It should be noted that the three-way negotiation only supports the ECDH protocol for the time being."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "91f58d96-ef7c-411d-9cc4-524ef1c0dd44",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "input_path = {alice: '.data/alice.csv', bob: '.data/bob.csv', carol: '.data/carol.csv'}\n",
    "output_path = {alice: '.data/alice_psi.csv', bob: '.data/bob_psi.csv', carol: '.data/carol_psi.csv'}\n",
    "spu_3pc.psi_csv(['uid', 'month'], input_path, output_path, 'alice', protocol='ECDH_PSI_3PC')"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "17461ac1-1409-45ef-a513-ecfe6482feb8",
   "metadata": {},
   "source": [
    "Similarly, we use pandas.DataFrame.join to verify the correctness of the result."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "b6ea04f0-6b8a-40d7-b07a-06e926884261",
   "metadata": {},
   "outputs": [],
   "source": [
    "keys = ['uid', 'month']\n",
    "df = da.join(db.set_index(keys), on=keys, how='inner', rsuffix='_bob', sort=True).join(\n",
    "    dc.set_index(keys), on=keys, how='inner', rsuffix='_carol', sort=True)\n",
    "expected = df[da.columns].astype({'uid': 'int64'}).reset_index(drop=True)\n",
    "\n",
    "da_psi = pd.read_csv('.data/alice_psi.csv')\n",
    "db_psi = pd.read_csv('.data/bob_psi.csv')\n",
    "dc_psi = pd.read_csv('.data/carol_psi.csv')\n",
    "\n",
    "pd.testing.assert_frame_equal(da_psi, expected)\n",
    "pd.testing.assert_frame_equal(db_psi, expected)\n",
    "pd.testing.assert_frame_equal(dc_psi, expected)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "6f4bb1ed-2530-46c4-b540-8c779dd93439",
   "metadata": {},
   "source": [
    "## What's Next"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "55afbe9a-2855-41d6-b867-ec61c4ba4d20",
   "metadata": {},
   "source": [
    "OK! Through this tutorial, we have seen how to do two-party and three-party data intersections via SPU. After completing the data intersection, we can perform machine learning modeling on the aligned dataset.\n",
    "\n",
    "- [Logistic Regression On SPU](./lr_with_spu.ipynb): Logistic regression modeling on SPU using JAX.\n",
    "- [Neural Network on SPU](./nn_with_spu.ipynb): Neural Network Modeling on SPU with JAX.\n",
    "- [Basic Split Learning](./split_learning_gnn.ipynb): Neural Network Modeling with TensorFlow and Split Learning."
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "885f952f3a9a8d4529917f03b3eb9071e6c6f6c8d203379e2eea819d4a2f2375"
  },
  "kernelspec": {
   "display_name": "Python 3.8.12 ('secretflow')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
